# ---
# jupyter:
#   jupytext:
#     cell_metadata_filter: -all
#     custom_cell_magics: kql
#     text_representation:
#       extension: .py
#       format_name: percent
#       format_version: '1.3'
#       jupytext_version: 1.11.2
#   kernelspec:
#     display_name: lna_env
#     language: python
#     name: python3
# ---

# %% [markdown]
# # Classify sentences with BERT

# %% [markdown]
# ### Imports

# %%
import csv
import pandas as pd
import os
import glob
import torch
import numpy as np
import torch.nn.functional as F
from torch.utils.data import TensorDataset, DataLoader, SequentialSampler
from tqdm import tqdm

from bert_preprocess import BertPreprocessor
from config import BATCH_SIZE, LABEL_VALUES


# %%
class SentimentPredictor:
    """
    Interface for interacting with the fine-tuned sentiment model.

    The model file is resolved when the predictor is created, but the model
    itself is only loaded on the first call to ``predict`` (or an explicit
    ``load_model`` call). Everything runs on the CPU.
    """

    def __init__(self):
        """
        Set the device, init the BertPreprocessor and resolve the model file.

        Raises:
            FileNotFoundError: if no model file is found.
        """
        self.device = torch.device("cpu")
        self.bp = BertPreprocessor()
        self.latest_model_filename = self._get_latest_model_filename()
        # Loaded lazily by predict() / load_model().
        self.model = None

    def _get_latest_model_filename(self, pattern="fine-tuned-model.pt"):
        """
        Return the most recently created file matching ``pattern``.

        Args:
            pattern: glob pattern, relative to the current working directory
                unless absolute. The default is the fixed file name this
                project uses; pass e.g. "*.pt" to pick the newest checkpoint.

        Raises:
            FileNotFoundError: if nothing matches ``pattern``.
        """
        candidates = glob.glob(pattern)
        if not candidates:
            # max() on an empty list would only raise an opaque ValueError.
            raise FileNotFoundError(
                f"No model file matching {pattern!r} (cwd: {os.getcwd()})"
            )
        # Note: getctime is creation time on Windows, metadata-change time on Unix.
        return max(candidates, key=os.path.getctime)

    def load_model(self):
        """
        Load the model file resolved in ``__init__`` onto ``self.device``.
        """
        # NOTE(review): torch.load unpickles arbitrary Python objects -- only load
        # trusted files. The file holds a whole pickled model (it is called
        # directly below), which PyTorch >= 2.6 rejects by default
        # (weights_only=True); pass weights_only=False there if the file is trusted.
        self.model = torch.load(self.latest_model_filename, map_location=self.device)

    def _prettify_probabilities(self, probabilities: np.ndarray, shorten=False) -> list:
        """
        Map probabilities to the label with the highest probability.

        Args:
            probabilities: a single probability vector, shape ``(num_labels,)``,
                or one vector per segment, shape ``(n, num_labels)``. Index i
                is assumed to correspond to ``LABEL_VALUES[i]``.
            shorten: if True, truncate each label to its first 3 characters.

        Returns:
            A list with one label per segment.
        """
        # Slicing with None keeps each label whole, whatever its length.
        cut_index = 3 if shorten else None
        rows = np.atleast_2d(probabilities)
        return [LABEL_VALUES[i][:cut_index] for i in rows.argmax(axis=1)]

    def _get_probabilies(self, dataloader):
        """
        Run the model over every batch and return softmax probabilities.

        Each segment gets one probability vector whose index corresponds to
        the index of the label in LABEL_VALUES,
        e.g. LABEL_VALUES = ["positive", "neutral", "negative"].
        Assumes the model's logits are shaped (batch, num_labels).

        Returns:
            np.ndarray of shape ``(n_segments, num_labels)``. When exactly one
            segment was given, a 1-D array of shape ``(num_labels,)`` is
            returned instead, so one-sentence-at-a-time callers keep working.
        """
        # Evaluation mode: the dropout layers are disabled.
        self.model.eval()

        batch_logits = []
        with torch.no_grad():
            for batch in dataloader:
                # Only input ids and attention masks are fed to the model.
                b_input_ids, b_attn_mask = (t.to(self.device) for t in batch[:2])
                output = self.model(b_input_ids, b_attn_mask)
                # The fine-tuned model returns an object with a .logits attribute
                # (presumably a Hugging Face output); accept a bare tensor too.
                batch_logits.append(getattr(output, "logits", output))

        if not batch_logits:
            return np.empty((0, len(LABEL_VALUES)), dtype=np.float32)

        # Keep every batch -> (n_segments, num_labels), then normalise per row.
        all_logits = torch.cat(batch_logits, dim=0)
        probs = F.softmax(all_logits, dim=1).cpu().numpy()

        if probs.shape[0] == 1:
            return probs[0]
        return probs

    def _make_predictable(self, segment_list: list) -> DataLoader:
        """
        Turn a list of segments into the DataLoader consumed by predict().

        Segments are preprocessed by BertPreprocessor and served in their
        original order, BATCH_SIZE at a time.
        """
        _, padding_token_ids, attention_masks = self.bp.preprocess(
            slim=True, segments=segment_list
        )

        data = TensorDataset(
            torch.tensor(padding_token_ids), torch.tensor(attention_masks)
        )

        return DataLoader(data, sampler=SequentialSampler(data), batch_size=BATCH_SIZE)

    def predict(self, segment_list: list, pretty=True, shorten=False, verbose=False) -> "list | np.ndarray":
        """
        Predict the sentiment of each element in segment_list.

        Args:
            segment_list: segments to classify.
            pretty: if True (default) return label strings, otherwise the raw
                probability array from ``_get_probabilies``.
            shorten: truncate labels to 3 characters; only affects the output
                when ``pretty=True``.
            verbose: print the probabilities and labels.
        """
        dataloader = self._make_predictable(segment_list)
        if self.model is None:
            self.load_model()
        probabilities = self._get_probabilies(dataloader)
        if verbose:
            print("#################")
            print(probabilities)
            print(self._prettify_probabilities(probabilities, shorten=shorten))
        if pretty:
            return self._prettify_probabilities(probabilities, shorten=shorten)
        return probabilities


# %% [markdown]
# ### Load data

# %% [markdown]
# Input data must be a list of lists, e.g. `[["J'aime le chocolat"], ["Je suis une chaise"]]`: each inner list holds one sentence and is passed to `predict` on its own.

# %%
df = pd.read_csv(
    "input-bert_sentence.csv",
    header=None,
    names=["sentence"],
    # Keep every line as literal text: by default pandas would turn lines such
    # as "NA", "null" or "nan" into NaN and numeric-only lines into numbers,
    # which the tokenizer cannot process.
    dtype=str,
    keep_default_na=False,
)
# Wrap each sentence in its own list so each predict() call gets one segment.
df["sentence"] = df["sentence"].apply(lambda x: [x])

# %%
l_sentences = df.sentence.tolist()

# %%
# One predictor for the whole run; its model file is only loaded on the first predict() call.
p1 = SentimentPredictor()

# %%
# Transform and predict: one predict() call per single-sentence list.
print("Started transformation...")
l_predicted = []
for sentence in tqdm(l_sentences):
    # pretty=False keeps the raw probability arrays; pass pretty=True to get label strings instead
    b = p1.predict(segment_list=sentence, pretty=False)
    l_predicted.append(b)

# %%
# Check classification: show the collected predictions as a table, one row per sentence
pd.DataFrame(data=l_predicted)

# %%
# Save the predictions into a CSV file, one row per input sentence
output_path = "output-bert_sentence.csv"
# newline="" is required by the csv module; write-only mode is all that is needed.
with open(output_path, "w", newline="", encoding="utf-8") as file:
    print(f"Saving to {output_path}")
    write = csv.writer(file)
    write.writerows(l_predicted)

# %%
